#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project: Mestizajes

Provides visualization utilities for analyzing cross-century interaction
patterns between fields of human activity in the Wikipedia biographies
network. The module filters directional connection data, builds century-by-
century matrices, and renders KDE heatmaps to reveal how knowledge flows
across historical periods and disciplines.
'''

import os
import warnings
from typing import Optional, Sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def get_tick_values(
    df: pd.DataFrame,
    percentiles: Sequence[float] = (1, 10, 25, 50, 75, 90, 99, 100),
) -> np.ndarray:
    """
    compute percentile-based tick values from the positive entries of a dataframe.

    the dataframe is flattened, every value that is not strictly positive
    (zeros, negatives and NaN) is discarded, and the requested percentiles of
    the remaining values are returned. this is useful for choosing meaningful
    tick marks or thresholds in visualizations of sparse matrices.

    parameters
    ----------
    df:
        input dataframe whose values will be used to compute percentiles.
        assumed to hold numeric values.
    percentiles:
        percentiles (on a 0-100 scale) to evaluate. defaults to
        (1, 10, 25, 50, 75, 90, 99, 100).

    returns
    -------
    np.ndarray
        array of percentile values, one per entry in `percentiles`.

    raises
    ------
    ValueError
        if `df` contains no positive values.
    """
    values = df.to_numpy().ravel()
    # `> 0` drops zeros (absent connections) as well as negatives and NaN
    positive_values = values[values > 0]
    if positive_values.size == 0:
        # np.percentile would otherwise fail with an opaque error on an empty array
        raise ValueError("cannot compute tick values: df contains no positive values")
    return np.percentile(positive_values, percentiles)


def remap_century(c: int) -> int:
    """
    map a signed century label onto a gap-free integer axis.

    century labels jump from -1 (1st century bce) straight to 1 (1st century
    ce); there is no century 0. shifting every ce century down by one closes
    that gap so consecutive centuries sit one unit apart on a plot axis,
    while bce centuries keep their original value.

    parameters
    ----------
    c:
        century label, negative for bce and positive for ce.

    returns
    -------
    int
        ``c - 1`` for ce centuries (1 -> 0, 2 -> 1, ...) and ``c`` unchanged
        otherwise (-1 -> -1, -5 -> -5). an input of 0, which is not a valid
        century, is also returned unchanged and therefore coincides with 1 ce.
    """
    if c <= 0:
        # bce centuries (and the invalid label 0) keep their position
        return c
    return c - 1


def filter_category(
    df: pd.DataFrame,
    category1_label: Optional[str] = None,
    category2_label: Optional[str] = None,
    category3_label: Optional[str] = None,
) -> pd.DataFrame:
    """
    keep only the edges that leave a given source category, optionally
    restricted to one or two target categories.

    rows are first selected by 'Source_Category' == `category1_label`. if
    either `category2_label` or `category3_label` is given (truthy), the
    selection is further narrowed to rows whose 'Target_Category' equals one
    of the given labels; otherwise every target category is kept. the
    original index of the surviving rows is preserved.

    parameters
    ----------
    df:
        interaction dataframe with at least the columns 'Source_Category'
        and 'Target_Category'.
    category1_label:
        source category to keep. required despite the None default.
    category2_label:
        optional first target category to keep.
    category3_label:
        optional second target category to keep.

    returns
    -------
    pd.DataFrame
        subset of `df` matching the requested source/target categories.

    raises
    ------
    ValueError
        if `category1_label` is missing or empty.
    """
    if not category1_label:
        raise ValueError("category1_label is mandatory")

    from_source = df.loc[df["Source_Category"] == category1_label]

    # empty/None target labels are ignored, mirroring a truthiness check
    wanted_targets = [label for label in (category2_label, category3_label) if label]
    if not wanted_targets:
        return from_source

    keep = from_source["Target_Category"].isin(wanted_targets)
    return from_source.loc[keep]


def _build_interaction_grid(df: pd.DataFrame, value_col: str) -> pd.DataFrame:
    """
    turn filtered edge rows into a long-format, square century-by-century grid.

    parameters
    ----------
    df:
        filtered interaction dataframe with 'Source_Century', 'Target_Century'
        and `value_col` columns.
    value_col:
        name of the weight column to aggregate. rows sharing the same
        (source, target) century pair are combined with pivot_table's default
        aggregation (mean).

    returns
    -------
    pd.DataFrame
        one row per (target century, source century) pair over the union of
        all centuries present, with columns 'Target_Century', 'Source_Century',
        `value_col` (0 for pairs with no connections), plus the remapped
        'Source_Pos' and 'Target_Pos' positions produced by `remap_century`.
    """
    # target centuries as rows, source centuries as columns
    df_pivot = df.pivot_table(
        columns="Source_Century",
        index="Target_Century",
        values=value_col,
        fill_value=0,
    )

    # make the matrix square over every century seen on either axis so that
    # missing (source, target) pairs contribute an explicit zero weight
    all_centuries = sorted(set(df_pivot.index).union(df_pivot.columns))
    df_pivot = df_pivot.reindex(index=all_centuries, columns=all_centuries, fill_value=0)

    # one row per (target, source) pair, as expected by seaborn
    df_long = df_pivot.stack(future_stack=True).reset_index(name=value_col)

    # continuous positions that skip the non-existent century 0
    df_long["Source_Pos"] = df_long["Source_Century"].apply(remap_century)
    df_long["Target_Pos"] = df_long["Target_Century"].apply(remap_century)
    return df_long


def visualise_interactions(
    df: pd.DataFrame,
    category1_label: str,
    category2_label: str,
    direction: str,
    visualize: bool = True,
    save_figure: bool = False,
) -> None:
    """
    visualise the intensity of inter-century interactions between two categories.

    the function:
    - filters the input dataframe to retain interactions where the source
      category is `category1_label` and the target category is `category2_label`
    - aggregates normalized connection values into a square
      century-by-century grid
    - remaps centuries into a continuous numeric scale spanning bce and ce
    - generates a 2d kernel density estimate (kde) heatmap of connection density
      over source and target centuries

    parameters
    ----------
    df:
        input dataframe with at least the following columns:
        - 'Source_Category'
        - 'Target_Category'
        - 'Source_Century'
        - 'Target_Century'
        - 'n_connections_norm_{direction}' (where {direction} is the `direction` argument).
    category1_label:
        label of the source category of interest (e.g. "Science").
    category2_label:
        label of the target category of interest (e.g. "Politics").
    direction:
        suffix that selects which normalized connection column to use, for
        example "row" or "col", leading to columns such as
        'n_connections_norm_row'.
    visualize:
        if True, create and display the kde plot. if False, the function
        only prepares the data and performs no plotting.
    save_figure:
        if True (and `visualize` is True), save the resulting figure to the
        'figures/connectivity_analysis/' directory. the file name is based
        on the categories and direction.

    returns
    -------
    None
        the function is called for its side effects (filtering and plotting).
        if no rows match the two categories, a UserWarning is emitted and the
        function returns without plotting.
    """
    value_col = f"n_connections_norm_{direction}"

    # filter to the chosen source and target categories
    df = filter_category(df, category1_label, category2_label)
    if df.empty:
        # an empty selection cannot be pivoted or density-estimated meaningfully
        warnings.warn(
            f"no {category1_label} → {category2_label} interactions found; "
            "nothing to visualise",
            stacklevel=2,
        )
        return

    df_long = _build_interaction_grid(df, value_col)

    if not visualize:
        return

    sns.set_context("paper", font_scale=2)
    fig, ax = plt.subplots(figsize=(14, 10))

    # kde heatmap where density is weighted by normalized connection values;
    # the colorbar label is set here via cbar_kws, so there is no need to
    # fetch the colorbar back out of ax.collections afterwards
    sns.kdeplot(
        data=df_long,
        x="Target_Pos",  # x-axis: target century
        y="Source_Pos",  # y-axis: source century
        weights=value_col,
        fill=True,
        thresh=0,
        cmap="cubehelix",
        levels=15,
        bw_adjust=0.25,
        cbar=True,
        cbar_kws={"label": "Connection Density"},
        ax=ax,
    )

    # tick labels: centuries -5..-1 and 1..20 (there is no century 0)
    desired_centuries = list(range(-5, 0)) + list(range(1, 21))
    tick_positions = [remap_century(c) for c in desired_centuries]
    tick_labels = [str(c) for c in desired_centuries]

    # axis limits live in the remapped space
    min_pos = remap_century(-5)
    max_pos = remap_century(20)
    ax.set_xlim(min_pos, max_pos)
    ax.set_ylim(min_pos, max_pos)

    # ticks at remapped positions, labelled with the original centuries
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)

    # diagonal reference line: same century on both axes
    ax.plot(
        [min_pos, max_pos],
        [min_pos, max_pos],
        linestyle="--",
        linewidth=1,
        color="white",
    )

    ax.set_title(
        f"{category1_label} → {category2_label} connections over time\n"
        f"({direction.capitalize()} normalized)",
        pad=20,
    )
    ax.set_xlabel(f"Target century ({category2_label})")
    ax.set_ylabel(f"Source century ({category1_label})")

    # light grid to help read the plot
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_figure:
        save_dir = "figures/connectivity_analysis/"
        os.makedirs(save_dir, exist_ok=True)
        filename = (
            f"conns_{category1_label}_{category2_label}_{direction}.png"
            .replace(" ", "_")
            .replace("&", "")
            .lower()
        )
        fig.savefig(
            os.path.join(save_dir, filename),
            dpi=500,
            bbox_inches="tight",
            transparent=True,
        )

    plt.show()
    # in non-interactive (e.g. headless/batch) use, free the figure so that
    # repeated calls do not accumulate open figures; in interactive mode the
    # window is left for the user to close
    if not plt.isinteractive():
        plt.close(fig)